{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UkJ8Xh2gUknk"
      },
      "source": [
        "\n",
        "Copyright 2021 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "     https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "engAKP-iifhM"
      },
      "source": [
        "# Loading precomputed Myrtle-10 kernels from CIFAR-10\n",
        "\n",
        "Loading the full CIFAR-10 kernels requires a large amount of RAM (32GB+).\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pmf9KNv3xQop"
      },
      "outputs": [],
      "source": [
        "# Imports\n",
        "from tensorflow.io import gfile\n",
        "import os\n",
        "import numpy as np\n",
        "import itertools\n",
        "import concurrent.futures\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 1457,
          "status": "ok",
          "timestamp": 1618254012591,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "hV1V1Fm93wUX",
        "outputId": "a30f914e-f30d-414a-d987-aa271b6925c8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Colab authentication failed!!\n",
            "Corruption directory names: ['brightness_1', 'brightness_2', 'brightness_3', 'brightness_4', 'brightness_5', 'clean', 'contrast_1']\n"
          ]
        }
      ],
      "source": [
        "# External\n",
        "# Best-effort Colab authentication: if `google.colab` is unavailable or the\n",
        "# user cannot be authenticated, report it and carry on so the listing below\n",
        "# can still be attempted with whatever access the environment already has.\n",
        "try:\n",
        "  from google.colab import auth\n",
        "  auth.authenticate_user()\n",
        "except Exception:  # Not a bare `except`, so Ctrl-C still interrupts the cell.\n",
        "  print('Colab authentication failed!!')\n",
        "\n",
        "KERNEL_PATH = 'gs://neural-tangents-kernels/infinite-uncertainty/kernels/myrtle-10'\n",
        "\n",
        "\n",
        "print('Corruption directory names:', gfile.listdir(KERNEL_PATH)[:7])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TuMAIMYZWDuL"
      },
      "source": [
        "# Loading computed kernels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LPxZa8LbQeCc"
      },
      "outputs": [],
      "source": [
        "# Convenient packaging for kernels and labels\n",
        "class LoadedKernel:\n",
        "  \"\"\"Holds a precomputed kernel matrix and its labels in memory.\n",
        "\n",
        "  The kernel is read from per-block files named `{kernel_type}-{row}-{col}`,\n",
        "  where `row`/`col` are block ids in the full 60k x 60k matrix and every block\n",
        "  has `BLOCK_SIZE` rows and columns. Only blocks on or above the diagonal are\n",
        "  read; the mirrored block is filled in with the transpose.\n",
        "\n",
        "  Attributes:\n",
        "    kernel: Square array covering the first `n_train_load` rows followed by the\n",
        "      rows of the blocks listed in `TEST_BLOCK_IDS`.\n",
        "    labels: The first `n_train_load` entries of the label file followed by its\n",
        "      last `len(TEST_BLOCK_IDS) * BLOCK_SIZE` entries.\n",
        "    clean_ids: `(row, col)` block ids loaded from the 'clean' directory.\n",
        "    test_ids: `(row, col)` block ids reloaded by `update_test`.\n",
        "  \"\"\"\n",
        "\n",
        "  BLOCK_SIZE = 5000  # Rows/columns per stored kernel block.\n",
        "  TEST_BLOCK_IDS = (10, 11)  # Block ids placed after the training blocks.\n",
        "\n",
        "  def __init__(self,\n",
        "               kernel_type,\n",
        "               n_train_load=5000,\n",
        "               kernel_path='/tmp/kernels',\n",
        "               dtype=np.float32):\n",
        "    \"\"\"Loads the labels and the 'clean' kernel blocks.\n",
        "\n",
        "    Args:\n",
        "      kernel_type: File-name prefix of the block files (e.g. 'ntk').\n",
        "      n_train_load: Number of leading training rows to load; must be a multiple\n",
        "        of `BLOCK_SIZE` and must not reach into the test block ids.\n",
        "      kernel_path: Directory holding one sub-directory per test type.\n",
        "      dtype: dtype the loaded blocks are cast to.\n",
        "    \"\"\"\n",
        "    self.kernel_path = kernel_path\n",
        "    self.k_type = kernel_type\n",
        "\n",
        "    # Kernels are loaded in whole blocks, and training block ids must stay\n",
        "    # below the first test block id or the two ranges would collide.\n",
        "    max_train = self.BLOCK_SIZE * self.TEST_BLOCK_IDS[0]\n",
        "    assert n_train_load % self.BLOCK_SIZE == 0, (\n",
        "        f'n_train_load must be a multiple of {self.BLOCK_SIZE}')\n",
        "    assert 0 \u003c= n_train_load \u003c= max_train, (\n",
        "        f'n_train_load must be between 0 and {max_train}')\n",
        "    self.n_train_load = n_train_load\n",
        "    self.max_train_id = int(np.ceil(n_train_load / self.BLOCK_SIZE))\n",
        "\n",
        "    train_block_ids = list(range(self.max_train_id))\n",
        "    block_ids = train_block_ids + list(self.TEST_BLOCK_IDS)\n",
        "    # Upper-triangular blocks over train + test ids.\n",
        "    self.clean_ids = [(row, col)\n",
        "                      for row, col in itertools.product(block_ids, block_ids)\n",
        "                      if row \u003c= col]\n",
        "    # Train-vs-test blocks only; these are what `update_test` replaces.\n",
        "    self.test_ids = list(\n",
        "        itertools.product(train_block_ids, self.TEST_BLOCK_IDS))\n",
        "\n",
        "    self.dtype = dtype\n",
        "    n_total = self.BLOCK_SIZE * len(block_ids)\n",
        "    self.kernel = np.zeros(shape=(n_total, n_total), dtype=self.dtype)\n",
        "    self._load_labels()\n",
        "    self._load_kernel()\n",
        "\n",
        "  def _update_kernel(self, test_type, kernel_type, ids):\n",
        "    \"\"\"Reads the given blocks from `test_type`'s directory into `self.kernel`.\n",
        "\n",
        "    Args:\n",
        "      test_type: Sub-directory of `kernel_path` to read from (e.g. 'clean').\n",
        "      kernel_type: File-name prefix of the block files.\n",
        "      ids: Iterable of `(row, col)` block ids in the full-matrix numbering.\n",
        "\n",
        "    Raises:\n",
        "      AssertionError: If the directory does not exist.\n",
        "      Exception: Whatever a worker raised while reading a block.\n",
        "    \"\"\"\n",
        "    filedir = os.path.join(self.kernel_path, test_type)\n",
        "    assert gfile.exists(filedir), f\"File path {filedir} doesn't exist\"\n",
        "    filepath = os.path.join(filedir, kernel_type)\n",
        "    size = self.BLOCK_SIZE\n",
        "    first_test_id = self.TEST_BLOCK_IDS[0]\n",
        "\n",
        "    def _update_kernel_from_indicies(index):\n",
        "      row, col = index\n",
        "      # Block files are numbered by their position in the 60k x 60k matrix.\n",
        "      with gfile.GFile(f'{filepath}-{row}-{col}', 'rb') as f:\n",
        "        val = np.load(f).astype(self.dtype)\n",
        "\n",
        "      # In self.kernel the test blocks sit directly after the loaded training\n",
        "      # blocks, so shift their ids down by the number of skipped blocks.\n",
        "      if row \u003e= self.max_train_id:\n",
        "        row += self.max_train_id - first_test_id\n",
        "      if col \u003e= self.max_train_id:\n",
        "        col += self.max_train_id - first_test_id\n",
        "      rows = slice(row * size, (row + 1) * size)\n",
        "      cols = slice(col * size, (col + 1) * size)\n",
        "      self.kernel[rows, cols] = val\n",
        "      if col \u003e row:\n",
        "        # Off-diagonal block: fill the mirrored position with the transpose.\n",
        "        self.kernel[cols, rows] = val.T\n",
        "\n",
        "    with concurrent.futures.ThreadPoolExecutor(max_workers=200) as executor:\n",
        "      # Drain the iterator: executor.map only re-raises a worker's exception\n",
        "      # when its result is retrieved, and a dropped failure would silently\n",
        "      # leave a block of zeros in the kernel.\n",
        "      list(executor.map(_update_kernel_from_indicies, ids))\n",
        "\n",
        "  def _load_kernel(self):\n",
        "    \"\"\"Fills `self.kernel` with the 'clean' blocks listed in `clean_ids`.\"\"\"\n",
        "    self._update_kernel('clean', self.k_type, self.clean_ids)\n",
        "\n",
        "  def _load_labels(self):\n",
        "    \"\"\"Reads the 'clean' label file and keeps the rows matching `self.kernel`.\"\"\"\n",
        "    with gfile.GFile(os.path.join(self.kernel_path, 'clean', 'labels'),\n",
        "                     'rb') as f:\n",
        "      all_labels = np.load(f)\n",
        "    n_train = self.max_train_id * self.BLOCK_SIZE\n",
        "    n_test = len(self.TEST_BLOCK_IDS) * self.BLOCK_SIZE\n",
        "    self.labels = np.concatenate([all_labels[:n_train], all_labels[-n_test:]])\n",
        "\n",
        "  def update_test(self, test_type):\n",
        "    \"\"\"Replaces the train-vs-test blocks with those stored under `test_type`.\"\"\"\n",
        "    self._update_kernel(test_type, self.k_type, self.test_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-n6txn1PWZwA"
      },
      "outputs": [],
      "source": [
        "# Check kernels for sanity\n",
        "def kernel_inference(kdd, ktd, y_train, diag_reg):\n",
        "  \"\"\"Returns `ktd @ solve(kdd + ridge * I, y_train)`.\n",
        "\n",
        "  The ridge is `diag_reg` scaled by the mean of `kdd`'s diagonal.\n",
        "  \"\"\"\n",
        "  ridge = diag_reg * np.mean(np.diag(kdd))\n",
        "  regularized_kdd = kdd + ridge * np.eye(kdd.shape[0])\n",
        "  solution = np.linalg.solve(regularized_kdd, y_train)\n",
        "  return ktd.dot(solution)\n",
        "\n",
        "\n",
        "def run_kerenel_inference(kernel, n_train, seed=0):\n",
        "  \"\"\"Runs kernel inference on a random training subset over a ridge sweep.\n",
        "\n",
        "  Args:\n",
        "    kernel: Object with a square `kernel` array whose last 10000 rows/columns\n",
        "      are treated as the test points, and a `labels` array indexed the same way.\n",
        "    n_train: Number of training points to sample without replacement.\n",
        "    seed: Seed for the subset selection.\n",
        "\n",
        "  Returns:\n",
        "    A `(predictions, diag_regs, accs)` tuple: the test predictions for every\n",
        "    regularizer, the 20 regularizer values, and the test accuracy for each.\n",
        "  \"\"\"\n",
        "  kernel_train_size = kernel.kernel.shape[0] - 10000\n",
        "  assert kernel_train_size \u003e= n_train\n",
        "  # A private RandomState draws the same subset as `np.random.seed(seed)`\n",
        "  # followed by `np.random.choice`, without resetting the global NumPy RNG.\n",
        "  rng = np.random.RandomState(seed)\n",
        "  support_idx = rng.choice(kernel_train_size, n_train, replace=False)\n",
        "  train_labels = kernel.labels[:kernel_train_size]\n",
        "  test_labels = kernel.labels[kernel_train_size:]\n",
        "  support_y = train_labels[support_idx]\n",
        "  # Index both axes at once so no (n_train x N) intermediate copy is made.\n",
        "  kdd = kernel.kernel[np.ix_(support_idx, support_idx)]\n",
        "  ktd = kernel.kernel[kernel_train_size:, :][:, support_idx]\n",
        "  accs = []\n",
        "  predictions = []\n",
        "  diag_regs = np.logspace(-8, 2, 20)\n",
        "  for diag_reg in diag_regs:\n",
        "    pred = kernel_inference(kdd, ktd, support_y, diag_reg)\n",
        "    acc = np.mean(np.argmax(pred, axis=-1) == np.argmax(test_labels, axis=-1))\n",
        "    predictions.append(pred)\n",
        "    accs.append(acc)\n",
        "  return np.array(predictions), diag_regs, np.array(accs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 10138,
          "status": "ok",
          "timestamp": 1618254023146,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "0S9pMx1Lc19q",
        "outputId": "0d7a0062-aa54-41f8-c6cd-149bc692017c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CPU times: user 2.51 s, sys: 1.47 s, total: 3.98 s\n",
            "Wall time: 10 s\n"
          ]
        }
      ],
      "source": [
        "%%time\n",
        "# LoadedKernel asserts that n_train_load is a multiple of 5000.\n",
        "N_TRAIN_LOAD = 5000\n",
        "KERNEL_TYPE = 'ntk'  # 'nngp' or 'ntk'\n",
        "\n",
        "# Construction reads the labels and the 'clean' kernel blocks from KERNEL_PATH.\n",
        "Kernel = LoadedKernel(\n",
        "    n_train_load=N_TRAIN_LOAD,\n",
        "    kernel_type=KERNEL_TYPE,\n",
        "    kernel_path=KERNEL_PATH,\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 265
        },
        "executionInfo": {
          "elapsed": 331,
          "status": "ok",
          "timestamp": 1618254023483,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "NT-QmF7cMCky",
        "outputId": "3eca919d-a46e-407f-9978-f7ef0ae72a63"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAVLElEQVR4nO3da4yc1XkH8P9/7nuxvfiyvmAbUmIJm1YllYVS0VSOIiVOvgCq\nkIJQZSlRnaiggBSpQpEq8qVqvuTyJSRyBIIPBGQpoVgFtSArEW2RUBZCg5FDTYHYa2/W2LueXe+u\n5/r0ww5oazznOezMzuxy/j/J8u6cs+857zvvM+/svM8+h2YGEfnky/R7AiLSGwp2kUQo2EUSoWAX\nSYSCXSQRuV4Oli8NWWFoY7APm+FtZCsNd5xmwX8Na+bo9snNh8dq5v1xWPfvdtC5I2JZf66VEX8u\nhbI/F8u6XXz058tGeC61df42CmXnZIkYBwAaA+GdNn8qsIjLZqbu98nOXgm219eXgu2VuSnUr8xd\nc8Y9DfbC0Eb86cEHg31ylfCTM/x22R1nYec6t8/8qL/rm8amg+2VHcPuNgrTFbcPF2rB9vp1A+42\n3rmr6Pa58bnwOABQG3aiPeLEb+QjAnU2HKjnPuc/P7tfCAcGAOSmF9w+5X0jwfaY/an5pwJKU/4L\nz8jxU8H2qYN7gu1vPvejtm0dvY0neZDkWyTfJvlQJ9sSkZW17GAnmQXwYwBfBrAPwD0k93VrYiLS\nXZ1c2W8D8LaZvWNmVQBPA7ijO9MSkW7rJNivB3Bmyffjrcf+H5KHSY6RHKtfmetgOBHpRCfBfq1P\nLT7yCYSZHTGz/Wa2P1ca6mA4EelEJ8E+DmDXku93AjjX2XREZKV0Euy/AbCH5KdIFgB8FcCx7kxL\nRLpt2ffZzaxO8n4A/w4gC+AxM3sz9DNs+vfRm15SRybi9SniXrA7ToSYZIuuiHlJ7lEupGX8nY5J\nAmKzw/MAgEUk70SdL10QdVwynf85ubvPgeaOkmrM7HkAz3eyDRHpDeXGiyRCwS6SCAW7SCIU7CKJ\nULCLJELBLpKInv49e7bS8P8e3bkvmnn/kjvOYM0vcFGYDhcBAAC+PxVsH2j44/Cy//fUVq0G2wvz\ng+42Rl/Z5vYpng7/fT4AFIqFcIecf32wiD4Zp0jD6Kub3G0Uxv394Zx//Nc7eQGW92/6N0p+KOXm\n/HoCzdnLwfaRt8Lt2Svt6wToyi6SCAW7SCIU7CKJULCLJELBLpIIBbtIIhTsIolQsIskoqdJNc1C\nxl/Awfnb/JiEmYVd6/0+m/xEiY3TG4Lt1W3+YhT56bzbhwvhpJraRr9239QtfvGE4fHw/gBAfSh8\nSsSsGNOIWJGnOB1O3pna6w80dNZ/nnNlJ0kIwNzu8AoPjULEIhFDfp9i2T8X1r8dTqAq3xg+F5r/\n0/7Y68oukggFu0giFOwiiVCwiyRCwS6SCAW7SCIU7CKJULCLJKK3STU5Yn40PKS3EkhMhZmYhJnL\nu/zXuQ2nwmMtbPETNmJkC+H5Vjf64+T2zrh9rrwWTh4BgOpQ+Lg0I86YRtHvY9lwgknl0+FKNgBQ\n+W3EQBGrxsxtC+9zTFJNdcSfSr3kn5cbSuF98ubaDBxWXdlFEqFgF0mEgl0kEQp2kUQo2EUSoWAX\nSYSCXSQRCnaRRPQ0qSY338CmMX/JnhBvSSbArzAD+AkzAJD9w2R4G9N+kgrn/eQQq4aXBRo6P+Bu\no/7UdrfP+tfD+wMAVgon8JizPBeAqCWiOBc+LpnaZncbwyci9mfeX/5p61x4rJjlrGKWf8pGLP/U\nuBA+v7e9HJ7re5fbV3LqKNhJvgdgFkADQN3M9neyPRFZOd24sn/ezC50YTsisoL0O7tIIjoNdgPw\nAslXSR6+VgeSh0mOkRyr1uc6
HE5ElqvTt/G3m9k5kqMAXiT5ezN7aWkHMzsC4AgAbBjcYR2OJyLL\n1NGV3czOtf4/D+AZALd1Y1Ii0n3LDnaSQyTXffA1gC8CONGtiYlId3XyNn4rgGe4WBwgB+DnZvZv\noR9o5jOo7AjfmzanTsBAw18RJmallpjCE9599NoO/35+7qI/DivhFWEaI/6KMBf3+YURBs/5820M\nOMVFcn4hB4vok3dWapm62T81Byb8FWGys/7xn98Zfp6b+W6tCOPPZejdcE7F7C5nRZhT7c+DZQe7\nmb0D4M+X+/Mi0lu69SaSCAW7SCIU7CKJULCLJELBLpIIBbtIIhTsIonoafEK1g2F6Upn27jsFyPI\nT4dXG4keyyk8EZMwk4koXgEnqSYTUTBieNwvpJG75B+7TMU5dhFzaUYUe8jOhs+D4bMRxUXK/v7E\nFA8pXhwMtscUr8gt+KGUnwk/z4BfyKR0wTlX6s32be7oIvKJoGAXSYSCXSQRCnaRRCjYRRKhYBdJ\nhIJdJBEKdpFE9Dapxgxc8FfFCLGqn5jABb9PtuBXdvESHLwKMwDchJlujVOcaZ9M8eF2rvjbcV/9\n6VdkYcyKMM4+FcsR+xN1bCPOhflwn5ikGjb9WqrZef/ct1rd2YazP4F56MoukggFu0giFOwiiVCw\niyRCwS6SCAW7SCIU7CKJ6Ol9dssS9evCK154Lz+F+XChAQCobfRXUKlujFid43x4rjErtcQUnvDu\nFzdH/MIUszv9vIHh//Xn2xwMF6+wjH+fvZn39zmfD893Zpd/ag6e7s7xr2wOn1Mx+1Mb8vsUyv5z\nVCwVg+1Xtiy/0Iau7CKJULCLJELBLpIIBbtIIhTsIolQsIskQsEukggFu0gieppUUxnJ4J27wkkD\n3svP6Cvb3HGmbvETP3J7Z9w+9ae2B9sv7otIZIlYqcUrPBGTMPPf//CI2+eWgb93+1RHwkUYmnm/\nSENz0C88MXBmfbD9ka/91N3Gt0rfdPsMTWxw+/zxQCPYzlK4HQB2bL3o9nn39Ca3z96JHcH20/eG\ni1tUf6/iFSLJi6hCxMdInid5YsljG0m+SPJU6//rVnaaItKpmCv74wAOXvXYQwCOm9keAMdb34vI\nKuYGu5m9BGDqqofvAPBE6+snANzZ3WmJSLct93f2rWY2AQCt/0fbdSR5mOQYybHG3NwyhxORTq34\nB3RmdsTM9pvZ/uyQ/yeJIrIylhvskyS3A0Dr//Pdm5KIrITlBvsxAIdaXx8C8Gx3piMiK8VNqiH5\nFIADADaTHAfwMIDvAThK8usATgO4O2awQtlw43OdrQhTPD3t9hke9xMprrzmJ7usf30y2D54zh8n\nd2nB7eOt1BJTYSYmYeb6X/ufmTSL4QSeuEo1fp/8THguD8z7CTM7Xiq7fTIz/vEvTYeTXZo5P/es\nNtz2Y6sP3VD2k3NwZiLYvPPozcH2C1Ptj727F2Z2T5umL3g/KyKrhzLoRBKhYBdJhIJdJBEKdpFE\nKNhFEqFgF0mEgl0kET1e/gmoDftVV0IKRX/ZpvqQv1vViOV6rBQeqzHgj5OphJdTAvxXXG9JJsCv\nMAP4CTMAUFvX+SnRKPhJNZl6eL5VP18JjQH/uLDiJ7LUhsPPQCMiSag27PdhI+Lamg8f/+q68Daa\n2fbz0JVdJBEKdpFEKNhFEqFgF0mEgl0kEQp2kUQo2EUS0dP77AAA53akWxwhF3F/POJWfjNizy3j\n3NPM+fdW4WwDAMDwduIKRvj32WO242kUI+YSuNf74VycPo1ixP5EnAsx54s335jzKeZefNT54p4L\n/iba0ZVdJBEKdpFEKNhFEqFgF0mEgl0kEQp2kUQo2EUSoWAXSURvk2pIN/nAS7aISaRoFCL6FN0u\nbkKGRSRJNCPmS6dPMx+RGDLY9PtEJH54hSdiEmbqpc4TTBoDEUk1XUjeAfx9bvj1UtCMOJ9iin
qA\nHRbSCDTryi6SCAW7SCIU7CKJULCLJELBLpIIBbtIIhTsIolQsIskoqdJNWwYCrPh5A82w8kUmdkr\n7jjFaT8LwrIRq4nMhcfKl/1xsrMVf5xKNTxO3i+VMnBmvdsnPzPn9vFWaolJUompyJKfqwfbB8/6\nz09+JnzcACBz2T/+A1PhucQkaeUq/j4XL/mJT6iF92lgOrzCDRvtnz93L0g+RvI8yRNLHvsuybMk\nX2/9+4q3HRHpr5i38Y8DOHiNx39oZre2/j3f3WmJSLe5wW5mLwGY6sFcRGQFdfIB3f0kf9d6m39d\nu04kD5McIzlWq/q/M4rIylhusP8EwE0AbgUwAeD77Tqa2REz229m+/OFoWUOJyKdWlawm9mkmTXM\nrAngZwBu6+60RKTblhXsJLcv+fYuACfa9RWR1cG9z07yKQAHAGwmOQ7gYQAHSN4KwAC8B+AbKzdF\nEekGN9jN7J5rPPzocgarrSPOfS48ZNPJHxl9dZM7ztRePwml8mk/OSdT2xwe52Y/J2n4bMntUyyH\nky1mdvnjPPK1n7p9Hpj/ptunuiHcHrMsU0yVGS9p5h//7kl3G/+8cK8/zuSg2+ePtzvVeQbCiSwA\nMDzqf/h8fnyd2+fmU1uC7We+FP752qvt25QuK5IIBbtIIhTsIolQsIskQsEukggFu0giFOwiiehp\n8YpCuYndL4TvbxvD9zwL49PuOENn/UIOld/6S3gMn5gMtg9M+ONkywtuH694xeBp/28KvlXy76Hv\neKns9mkMhO9/x6zIE1Pgwis8EXMPfet/+fuTmZ13+xSnNwbbY1bkqQ/599C3zPj36+1c+Jy74Vjb\nvzkDAExdat+mK7tIIhTsIolQsIskQsEukggFu0giFOwiiVCwiyRCwS6SiJ6vCJObdpJMMuHXH875\nSSq5iJVa4CTvAIDNh8fKzvrjcN4vkmHVcIJJxjkmADA04VSdAJCZiUnwcRI/upRU463UElN0IiZh\nhgv+ijCFS+E+MYlEuQU/lHJzNbePdy4Up8Nz7WhFGBH5ZFCwiyRCwS6SCAW7SCIU7CKJULCLJELB\nLpIIBbtIInqaVNMYyKK8b6SjbayPSNiY2z3s99nmv85tnQuvCDO/0x+neNFPDsnOhxMpKpsjVjU5\n4FdBKU37q+nUhsPHpRlx/BsFv8/AVD3Y7q3SAvgVZgA/YQYALv5ZuMqMRVwSqxsi9vmCXx1p03vh\n6kfv7wtXLWqcaj9ZXdlFEqFgF0mEgl0kEQp2kUQo2EUSoWAXSYSCXSQRPb3PbgQaef9+ZHAb+azb\nJ+Y+b0wfr2hBM2JfolZQccfxt8GSf5+9mfOfbu/5Mf/woxFRO6RRcPZ5IGJ/Io5L1PF3utSHIs4n\n/xZ61HFhNjwZbxsWmKqu7CKJcIOd5C6SvyJ5kuSbJB9oPb6R5IskT7X+Dy9CJSJ9FXNlrwP4tpnt\nBfBZAPeR3AfgIQDHzWwPgOOt70VklXKD3cwmzOy11tezAE4CuB7AHQCeaHV7AsCdKzRHEemCj/U7\nO8kbAXwGwCsAtprZBLD4ggBgtM3PHCY5RnKsfmWuw+mKyHJFBzvJYQC/APCgmc3E/pyZHTGz/Wa2\nP1fy1xkXkZURFewk81gM9CfN7JethydJbm+1bwdwfmWmKCLdEPNpPAE8CuCkmf1gSdMxAIdaXx8C\n8Gz3pyci3RKTVHM7gL8F8AbJ11uPfQfA9wAcJfl1AKcB3O1tyDJAzan3YJlwAkOj5E+5FpEEUR1x\nu7hjxYwTs1IIm+1X8Vgcx38DtmPrRbdPbfiaH6tc1cc5/hGJRM2IBJNcJbyd4VH/8536ULjoBBB3\n/L3CEzEJM5VNTbdPthrxRjqfDzZ7cw0lPblHwsz+E0C7Eb7g/byIrA7KoBNJhIJdJBEKdpFEKNhF\nEqFgF0mEgl0kEQp2kUT0tFJNpg6UpsIJJJYJt+fmau44xX
I4MQEA6iW/5ErWGatY9kuP5GfCq70A\nQHY+PE6h7M/13dP+ai83lP3qL2w4FWRy3akCVLwUTkI5P+4nzGyZ8fcn5nzxVmqJqTATkzBTuBQ+\ntwHAquHzpXQxvA0GFtrRlV0kEQp2kUQo2EUSoWAXSYSCXSQRCnaRRCjYRRKhYBdJBM38G/3dsiG/\nxf7yur/paBvN2ctun8zgoNuHJb/8SOPCVHicoQF3G1b1kzqsFsiEQNxcuXuH2wdnJvw+eSfPihHL\ndzHiGlJzko22bnE3Yecm/T5OkgoAZNavD7Z7SzIBcCvMxM6lMRku5ZjbtjXY/vKFoyhXz1/zSdKV\nXSQRCnaRRCjYRRKhYBdJhIJdJBEKdpFEKNhFEtHT4hX19SVMHdwT7GPOfdyRt/z77OUb/QUk57b5\nr3PbXt4cbJ/d5Y9TuhBTvCLc58oWP2/g9L3he/UAsPPozW6f6rrwcbGIy0PMqjED0+HCE2e+5I9z\nw7Hr3D7F6Yrb5/194ecxpniFt1IL4BeeAICt/xreztm7bwq2155un5OhK7tIIhTsIolQsIskQsEu\nkggFu0giFOwiiVCwiyRCwS6SiJ4WryD5PoA/LHloM4ALPZtA59bSfNfSXIG1Nd/VPNcbzOyalT96\nGuwfGZwcM7P9fZvAx7SW5ruW5gqsrfmupbkupbfxIolQsIskot/BfqTP439ca2m+a2muwNqa71qa\n64f6+ju7iPROv6/sItIjCnaRRPQt2EkeJPkWybdJPtSvecQg+R7JN0i+TnKs3/O5GsnHSJ4neWLJ\nYxtJvkjyVOt/v9JDj7SZ73dJnm0d49dJfqWfc/wAyV0kf0XyJMk3ST7QenzVHt92+hLsJLMAfgzg\nywD2AbiH5L5+zOVj+LyZ3bpK768+DuDgVY89BOC4me0BcLz1/WrxOD46XwD4YesY32pmz/d4Tu3U\nAXzbzPYC+CyA+1rn6mo+vtfUryv7bQDeNrN3zKwK4GkAd/RpLmuemb0E4Oq1qu4A8ETr6ycA3NnL\nOYW0me+qZGYTZvZa6+tZACcBXI9VfHzb6VewXw/gzJLvx1uPrVYG4AWSr5I83O/JRNpqZhPA4gkL\nYLTP84lxP8nftd7mr7q3xSRvBPAZAK9gDR7ffgX7tarqreZ7gLeb2V9g8deO+0j+db8n9An0EwA3\nAbgVwASA7/d1NlchOQzgFwAeNLOZfs9nOfoV7OMAdi35fieAc32ai8vMzrX+Pw/gGSz+GrLaTZLc\nDgCt/8PLg/aZmU2aWcPMmgB+hlV0jEnmsRjoT5rZL1sPr6njC/Qv2H8DYA/JT5EsAPgqgGN9mksQ\nySGS6z74GsAXAZwI/9SqcAzAodbXhwA828e5uD4InJa7sEqOMUkCeBTASTP7wZKmNXV8gT5m0LVu\nrfwIQBbAY2b2T32ZiIPkn2Dxag4s1tn/+WqbK8mnABzA4p9eTgJ4GMC/ADgKYDeA0wDuNrNV8aFY\nm/kewOJbeAPwHoBvfPA7cT+R/CsA/wHgDQDN1sPfweLv7avy+LajdFmRRCiDTiQRCnaRRCjYRRKh\nYBdJhIJdJBEKdpFEKNhFEvF/ogWtLgN6/68AAAAASUVORK5CYII=\n",
            "text/plain": [
              "\u003cFigure size 600x400 with 1 Axes\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Visual sanity check: show the top-left 25x25 corner of the loaded kernel.\n",
        "plt.imshow(Kernel.kernel[:25, :25])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "executionInfo": {
          "elapsed": 103100,
          "status": "ok",
          "timestamp": 1618254166428,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "w-10BRjrU44t",
        "outputId": "8dfb9029-d199-4257-93f0-103d99c9581f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "n_train = 500, corruption = speckle_noise\n",
            "Corruption Level 0, Acc: 0.5024\n",
            "Corruption Level 1, Acc: 0.2831\n",
            "Corruption Level 2, Acc: 0.2046\n",
            "Corruption Level 3, Acc: 0.1831\n",
            "Corruption Level 4, Acc: 0.1566\n",
            "Corruption Level 5, Acc: 0.1418\n"
          ]
        }
      ],
      "source": [
        "corruption_name = 'speckle_noise'  #@param ['brightness', 'contrast', 'defocus_blur', 'elastic', 'fog', 'frost', 'frosted_glass_blur', 'gaussian_blur', 'gaussian_noise', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']\n",
        "\n",
        "n_train = 500  #@param\n",
        "print(f'n_train = {n_train}, corruption = {corruption_name}')\n",
        "# Level 0 reads the 'clean' directory; levels 1-5 read the\n",
        "# '{corruption_name}_{level}' directories. Each pass overwrites the\n",
        "# train-vs-test blocks of `Kernel` in place before running inference.\n",
        "for c_level in range(6):\n",
        "  test_dir = 'clean' if c_level == 0 else f'{corruption_name}_{c_level}'\n",
        "  Kernel.update_test(test_dir)\n",
        "  preds, regs, accs = run_kerenel_inference(Kernel, n_train)\n",
        "  print(f'Corruption Level {c_level}, Acc: {np.max(accs)}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Myrtle kernels for all CIFAR-10 corruption (Infinite uncertainty)",
      "provenance": [
        {
          "file_id": "1DNoY9i-WKsmt486Q0Xr4ycJXeS-NrSz1",
          "timestamp": 1615917402756
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
